{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 第 9 章：循环神经网络 — PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 导入 MNIST 数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor()])\n",
    "\n",
    "# 训练集\n",
    "trainset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch',     # 选择数据的根目录\n",
    "                                      train=True,\n",
    "                                      download=True,    # 不从网络上download图片\n",
    "                                      transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                         shuffle=True, num_workers=2)\n",
    "# 测试集\n",
    "testset = torchvision.datasets.MNIST(root='./datasets/ch08/pytorch',     # 选择数据的根目录\n",
    "                                     train=False,\n",
    "                                     download=True,    # 不从网络上download图片\n",
    "                                     transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                        shuffle=False, num_workers=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset MNIST\n",
      "    Number of datapoints: 60000\n",
      "    Split: train\n",
      "    Root Location: ./datasets/ch08/pytorch\n",
      "    Transforms (if any): Compose(\n",
      "                             ToTensor()\n",
      "                         )\n",
      "    Target Transforms (if any): None\n"
     ]
    }
   ],
   "source": [
    "print(trainset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset MNIST\n",
      "    Number of datapoints: 10000\n",
      "    Split: test\n",
      "    Root Location: ./datasets/ch08/pytorch\n",
      "    Transforms (if any): Compose(\n",
      "                             ToTensor()\n",
      "                         )\n",
      "    Target Transforms (if any): None\n"
     ]
    }
   ],
   "source": [
    "print(testset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来展示一些训练样本图像"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXYAAAB6CAYAAACr63iqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAEvJJREFUeJzt3XnQVNWZx/HvIyoiBgUEZIsrcY1L3HDGmLijg2JcEgxREqlgEo06RaJoKjFWWYkiruVWVFSIoqi4EYwwxNE4RGVEJAZF5EUdZESJcQPHoOgzf/S9h/NCN728/d6m7/v7VL31Pn26+95zuc15T5977nPM3RERkfzYpNEVEBGR+lLDLiKSM2rYRURyRg27iEjOqGEXEckZNewiIjmjhl1EJGfa1LCb2RAzW2RmLWY2tl6VEhGR2lmtNyiZWSfgVeBoYBnwHHC6u79cv+qJiEi1Nm3Dew8CWtz9NQAzmwIMA0o27Gam21xFRKr3rrv3qvTFbRmK6Q+8GT1elpS1YmajzWyumc1tw75ERDqy/6nmxW3psVuRsvV65O4+AZgA6rGLiGShLT32ZcDA6PEA4K22VUdERNqqLQ37c8AgM9vRzDYHhgPT6lMtERGpVc1DMe6+xszOBWYCnYDb3f2lutVMRERqUvN0x5p2pjF2EZFaPO/uB1T6Yt15KiKSM2rYRURyRg27iEjOqGEXEckZNewiIjmjhl1EJGfUsIuI5IwadhGRnFHDLiKSM2rYRURyRg27iEjOqGEXEcmZtiy0ISKy0TvssMMAePHFF0PZBx980KjqZEI9dhGRnFHDLiKSMxqKATbffPMQ77777iHeZ599QvyTn/xkg9uYMGFCiGfMmAHAW29ppUBpu1691i5O/4Mf/ACAk08+OZQdfPDBIb777rtDfOONNwLwzDPPtHcVNwpbb711iMeNGxfiUaNGAXD11VeHsosuuii7ijWAeuwiIjmjhl1EJGc69NJ46de1+KvumWee2ebtvvDCCwAcf/zxoWzFihVt3m49DRgwIMTnn38+ACeccEIo+8pXvlL0ff/4xz8AmDx5ctHnP/rooxBPnDix4vqkw1b//Oc/K35Pnh1yyCEhvuGGG0K8//77V7yN9957D4BvfOMboeyll/K1LHGfPn1CfMstt4R42LBhIb7vvvsAOPfcc0NZ+jluIvVdGs/MbjezFWa2ICrrYWazzGxx8rt7rbUVEZH6KttjN7PDgFXA7919r6RsHPCeu19hZmOB7u5e9mpEo3rsxx57bIh79+4d4vQvfJcuXdplv3Hv6KSTTgrxa6+91i77q0Y8pzftnf/1r38NZQccUHHnoJVNNlnbV/jiiy8qft/tt98OwBVXXBHKlixZUlMdmtW2224b4kcffTTEBx54YIhbWloA+PGPf1x0G0ceeWSIx44dC8Dw4cNDWdp7bXZpT/2Pf/xjKPvqV78a4tmzZ4f4uOOOA2D16tUZ1a5d1LfH7u5PAe+tUzwMmJTEk4CTEBGRjUKt0x37uPtyAHdfbma9S73QzEYDo2vcj4iIVKmii6dmtgMwPRqK+cDdt4mef9/dy46zZzEUk35FS79+AVx77bUh7tatW3tXoah4WOa0004L8RtvvAFk/zVxl112CfGmmxb+vi9dujSU7bTTTiHebLPNQlxqCKCYPffcE4A1a9aEskMPPXSD75k2bVqIv/Wtb1W8r2aWDsHEF/9OOeWUEKf3RcTln3zySdFt7bXXXiFOL+IPGTIklMUXt1955ZUQr1y5sqa6Zyn+PKRDdl27dg1l8dz1+IJzTtR3KKaEd8ysL0Dye+Oa8iEi0oHV2rBPA0Ym8UjgkfpUR0RE2qqSWTH3AN8EtgXeAS4FHgbuA74MLAVOc/d1L7AW21a7D8WkX8HOOeec9t5VXVx++eUAXHrppQ2uSf2lQzjxZ2ybbcIIHpdddlmIf/SjHwHw2WefhbJ42Gbu3LntVs9G6NSpU4inTp0KtJ57/dRTT4V46NChIV61atUGt9u5c+cQb7/99gCceOKJoew3v/lNiK+77roQX3jhhRXXvb317ds3xOPHjw/xqaeeGuI//OEPAPzyl78MZQsXLsygduuLZ9r16NEjxPG9K+k9BW1Q1VBM2Yun7n56iaeOLFEuIiINpJQCIiI507TZHeNZGw888ECId9ttt6q39dBDD4X4rLPOKvqan/70p0DrmyDKiWfmbLXVVkVfc8wxxwD5HIqJh1VS7777bojjYz766KMB2HnnnUNZfGNZ3oZi7rjjjhCnQzAzZ84MZRdccEGIyw2/xOLZVeeddx5QOjPp+++/X/F2szRv3rwQ9+zZM8Q//OEPQ3zvvfcC2aegiGe0pTcdfv3rXw9l/fv3D/Hrr78e4jFjxgDwyCPZXI5Uj11EJGeaNglY/Nf71ltvrfh9cS8lzW399NNPh7J6Jgf6y1/+EuLBgwcXfc2HH34IwIgRI0LZY489Vrc6NIs0Z/hBBx0UyhYsCOmJWuXGb1bxt8z4G0h6QTlO8JXOQa9WvC5A+u0zTvNwzTXXhDi+YFpN+of2Ftfl9NPXXuJLe+ntKf1mHSfBixMDxhdwly9fDsBzzz0XyuJUHf369QvxJZdcAqy9Z6QGmcxjFxGRjZQadhGRnGnai6fxnN9y4gt28fueffbZutZpXRdffHGIn3jiiaKvSZfzii9wdcShmI7gqKOOCnE8n//hhx8GWmfXLCeer37llVeG+Pvf/36I0yGYeKgyzfgIG9fwSylZ5I+PU2b8+te/BuDss88OZZMmTQrx4YcfHuLFixdvcLvpthpBPXYRkZxRwy4ikjNNOxQTLztXbmZPOvsF2n/4JRavil5OfDyST+kMqHWl89dLDY3Ec6MPPvhgAE4++eRQ9t3vfrfo+2bNmgWsTVsBrTNtNoPtttsuxPEsqbaKs5vG2V+/9KUvAa2Xjix13oqJM1CmM2GgdSqHLKjHLiKSM2rYRURypmmHYswsxMWGYuJZKHPmzGn3+sQ3gXzve98DYNdddy37vvS2+zjTnuRTuvDIutJb0+PP0N577x3iM844I8RbbrnlBveR3jQDa2+seeedd6qvbAPFqSgmTpwY4nh4pBbxIiRPPvlkiF999dUQp2lAah1+ufPOO0Mcry/729/+tqq6tpV67CIiOdO0PfZyF0yXLVsW4nqmCSglvvU5TvBUTnrreDy/uKPYb7/9Qrzvvvuu9/z8+fOzrE67iz8X8XoBabKzOOlZreI0Ac3WU0+l32AAJk+eHOLhw4eHeMqUKVVvN76HJb61P04ZUK6nHiepS5eJjM9lnFBw5MiRIf7888+rrm9bqMcuIpIzathFRHKmaYdiyl08zUK6ujy0zjZZjSxumd6YxEvCxRed0lvk4/QPV111VXYVy0Ccnzue1/zzn/8cgLfffjuUPfrooyGOM0GmF+fii6gPPvhgiO+///461rgx4pQa99xzT4hvuummEO+xxx5A60kH5Zaf+/Of/xziU045JcTxZy4dHhw0aFAoO+KII0L8ne98J8RLliwBWudo/9Of/hTirIdfYmV77GY20MyeMLOFZvaSmZ2flPcws1lmtjj53b39qysiIuVUMhSzBhjj7rsDg4FzzGwPYCzwuLsPAh5PHouISINVvdCGmT0C3Jj8fNPdl5tZX+BJd9/gxO16LrQR17vYrdhxUv44c+Inn3wS4i222AJovbxWly5diu6vW7duQOvhgT59+oQ4Xh6rXH3TbH4Ao0aNAqqbN9vM4n+zt956a73n4yGKeAX6PEs/W59++mkoiz+T8Tz2dF73ypUrQ1nv3r1DHG8jb+K0G+n/748//jiUxQtexNLhq3QZSmidBTOeQTdw4MD13h9vN16oJJ0Lv2LFikqq31ZVLbRR1Ri7me0A7AfMAfq4+3KApHHvXeI9o4HR1exHRERqV3HDbmZbAQ8AF7j7R/HFyw1x9wnAhGQbjbnKKSLSgVQ0FGNmmwHTgZnufk1StogGDsXMmDEjxOkK95WYPn16iNOFD+KyeE3Deoq/Wnft2rVd9tEMNBRTmZ49e4Y4XtAhHbb51a9+Fcqyzhy4MejevTBXI76pK854WeyGt7gzOnv27BAvWrRovdfGs4vioZrVq1fXWOM2q++ap1b417gNWJg26olpQHpr1UjgkWpqKSIi7aNsj93MDgX+C/gbkF6lvITCOPt9wJeBpcBp7r7BiaT17LEPHTo0xHfddVeI03zKjZT+hR8xYkQoi+cot7S0ZF6njUW8XFixHnl88apYj76j+NnPfhbicePGhThNQbH//vtnXidpqPpePHX32UCpAfUjK92RiIhkQykFRERypmlTCsQXPOPMimmmtTSvcr2kQ1bxxZOXX345xNdff/165fPmzatrHfLghBNOKFqe3hZf7rbwjiK9SLqu559/PuOaSDNSj11EJGfUsIuI5EzTDsXE4mxwaXz11VeHsm9/+9sh7tevX8Xbve2220KcLlrQkedW1ype4CBe8i02fvx4oPV8/44sznwZW7BgQcY1kWakHruISM6oYRcRyZlcDMUUM2bMmBDffPPNIS63ynts4cKFIV6zZk19KtYB9erVK8SbbLK2LxHfgNSRb0YSqTf12EVEcia3PfZYuoSVNEa8hGBs/vz5IX7zzTezqs5Ga8CAASGOc6yLVEs9dhGRnFHDLiKSMx1iKEYaK12CcF2TJ0/OuCYbt3h5xFWrVoU4vvgc37MhUop67CIiOaOGXUQkZzQUI+1u5syZIT7wwANDHGfoFFi5cmWId9555wbWRJqdeuwiIjmjhl1EJGcqWfN0C+ApoDOFoZup7n6pme0ITAF6APOAM9z90zLbqtuapyIiHUhVa55W0mNfDRzh7vsA+wJDzGwwcCVwrbsPAt4HRtVSWxERqa+yDbsXpJNqN0t+HDgCmJqUTwJOapcaiohIVSoaYzezTmY2H1gBzAKWAB+4e5rycBnQv32qKCIi1aioYXf3z919X2AAcBCwe7GXFXuvmY02s7lmNrf2aoqISKWqmhXj7h8ATwKDgW3MLJ0HPwAomlDb3Se4+wHVDPyLiEjtyjbsZtbLzLZJ4i7AUcBC4Ang1ORlI4FH2quSIiJSuUruPO0LTDKzThT+ENzn7tPN7GVgipldDrwA3LahjYiISDbKzmOv687M/g58DLyb2U6ztS06tmakY2tOHenYtnf3XqVevK5MG3YAM5ub1/F2HVtz0rE1Jx1baUopICKSM2rYRURyphEN+4QG7DMrOrbmpGNrTjq2EjIfYxcRkfaloRgRkZxRwy4ikjOZNuxmNsTMFplZi5mNzXLf9WZmA83sCTNbaGYvmdn5SXkPM5tlZouT390bXddaJInfXjCz6cnjHc1sTnJc95rZ5o2uYy3MbBszm2pmryTn7pAcnbN/Tz6LC8zsHjPbolnPm5ndbmYrzGxBVFb0PFnBDUm78qKZfa1xNS+vxLFdlXwmXzSzh9K7/ZPnLk6ObZGZHVvJPjJr2JM7V28CjgP2AE43sz2y2n87WAOMcffdKeTOOSc5nrHA40me+seTx83ofAqpI1J5yb9/PTDD3XcD9qFwjE1/zsysP3AecIC77wV0AobTvOdtIjBknbJS5+k4YFDyMxq4JaM61moi6x/bLGAvd98beBW4GCBpU4YDeybvuTlpSzcoyx77QUCLu7+WrLQ0BRiW4f7ryt2Xu/u8JF5JoYHoT+GYJiUva8o89WY2APg34HfJYyMH+ffNrBtwGEn6C3f/NEls1/TnLLEp0CVJzrclsJwmPW/u/hTw3jrFpc7TMOD3ydoRz1JIUNg3m5pWr9ixuft/RGnQn6WQWBEKxzbF3Ve7++tAC4W2dIOybNj7A29Gj3OTw93MdgD2A+YAfdx9ORQaf6B342pWs+uAC4Evksc9yUf+/Z2AvwN3JMNMvzOzruTgnLn7/wLjgaUUGvQPgefJx3lLlTpPeWtbzgIeS+Kaji3Lht2KlDX9XEsz2wp4ALjA3T9qdH3aysyGAivc/fm4uMhLm/HcbQp8DbjF3fejkLeo6YZdiknGm4cBOwL9gK4UhijW1YznrZy8fD4xs19QGOadnBYVeVnZY8uyYV8GDIwel8zh3izMbDMKjfpkd38wKX4n/RqY/F7RqPrV6F+BE83sDQrDZUdQ6MFXlH9/I7cMWObuc5LHUyk09M1+zqCQTvt1d/+7u38GPAj8C/k4b6lS5ykXbYuZjQSGAiN87Q1GNR1blg37c8Cg5Cr95hQuCEzLcP91lYw73wYsdPdroqemUchPD02Yp97dL3b3Ae6+A4Vz9J/uPoIc5N9397eBN81s16ToSOBlmvycJZYCg81sy+SzmR5b05+3SKnzNA04M5kdMxj4MB2yaRZmNgS4CDjR3f8vemoaMNzMOpvZjhQuEP932Q26e2Y/wPEUrvguAX6R5b7b4VgOpfCV6EVgfvJzPIXx6MeBxcnvHo2uaxuO8ZvA9CTeKflAtQD3A50bXb8aj2lfYG5y3h4GuuflnAGXAa8AC4A7gc7Net6AeyhcK/iMQq91VKnzRGG44qakXfkbhZlBDT+GKo+thcJYetqW3Bq9/hfJsS0CjqtkH0opICKSM7rzVEQkZ9Swi4jkjBp2EZGcUcMuIpIzathFRHJGDbuISM6oYRcRyZn/BwyyzRo47o6iAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "          0           7           9           6\n"
     ]
    }
   ],
   "source": [
    "def imshow(img):\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "\n",
    "# 选择一个 batch 的图片\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# 显示图片\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "plt.show()\n",
    "# 打印 labels\n",
    "print(' '.join('%11s' % labels[j].numpy() for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义循环神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.rnn = nn.LSTM(          # 使用 LSTM 结构\n",
    "            input_size = 28,         # 输入每个元素的维度，即图片每行包含 28 个像素点\n",
    "            hidden_size = 84,        # 隐藏层神经元设置为 84 个\n",
    "            num_layers=2,            # 隐藏层数目，两层\n",
    "            batch_first=True,        # 是否将 batch 放在维度的第一位，(batch, time_step, input_size)\n",
    "        )\n",
    "        self.out = nn.Linear(84, 10) # 输出层，包含 10 个神经元，对应 0～9 数字\n",
    "\n",
    "    def forward(self, x):\n",
    "        r_out, (h_n, h_c) = self.rnn(x, None)   \n",
    "        # 选择图片的最后一行作为 RNN 输出\n",
    "        out = self.out(r_out[:, -1, :])\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Net(\n",
      "  (rnn): LSTM(28, 84, num_layers=2, batch_first=True)\n",
      "  (out): Linear(in_features=84, out_features=10, bias=True)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "net = Net()\n",
    "print(net) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 定义损失函数与优化算法"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(net.parameters(), lr=0.0001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[epoch: 1, mini-batch:  2000] loss: 1.592\n",
      "[epoch: 1, mini-batch:  4000] loss: 0.789\n",
      "[epoch: 1, mini-batch:  6000] loss: 0.564\n",
      "[epoch: 1, mini-batch:  8000] loss: 0.429\n",
      "[epoch: 1, mini-batch: 10000] loss: 0.357\n",
      "[epoch: 1, mini-batch: 12000] loss: 0.294\n",
      "[epoch: 1, mini-batch: 14000] loss: 0.290\n",
      "[epoch: 2, mini-batch:  2000] loss: 0.236\n",
      "[epoch: 2, mini-batch:  4000] loss: 0.213\n",
      "[epoch: 2, mini-batch:  6000] loss: 0.196\n",
      "[epoch: 2, mini-batch:  8000] loss: 0.195\n",
      "[epoch: 2, mini-batch: 10000] loss: 0.173\n",
      "[epoch: 2, mini-batch: 12000] loss: 0.161\n",
      "[epoch: 2, mini-batch: 14000] loss: 0.161\n",
      "[epoch: 3, mini-batch:  2000] loss: 0.140\n",
      "[epoch: 3, mini-batch:  4000] loss: 0.124\n",
      "[epoch: 3, mini-batch:  6000] loss: 0.141\n",
      "[epoch: 3, mini-batch:  8000] loss: 0.133\n",
      "[epoch: 3, mini-batch: 10000] loss: 0.121\n",
      "[epoch: 3, mini-batch: 12000] loss: 0.129\n",
      "[epoch: 3, mini-batch: 14000] loss: 0.107\n",
      "[epoch: 4, mini-batch:  2000] loss: 0.110\n",
      "[epoch: 4, mini-batch:  4000] loss: 0.099\n",
      "[epoch: 4, mini-batch:  6000] loss: 0.101\n",
      "[epoch: 4, mini-batch:  8000] loss: 0.101\n",
      "[epoch: 4, mini-batch: 10000] loss: 0.094\n",
      "[epoch: 4, mini-batch: 12000] loss: 0.087\n",
      "[epoch: 4, mini-batch: 14000] loss: 0.098\n",
      "[epoch: 5, mini-batch:  2000] loss: 0.082\n",
      "[epoch: 5, mini-batch:  4000] loss: 0.094\n",
      "[epoch: 5, mini-batch:  6000] loss: 0.082\n",
      "[epoch: 5, mini-batch:  8000] loss: 0.082\n",
      "[epoch: 5, mini-batch: 10000] loss: 0.081\n",
      "[epoch: 5, mini-batch: 12000] loss: 0.078\n",
      "[epoch: 5, mini-batch: 14000] loss: 0.078\n"
     ]
    }
   ],
   "source": [
    "num_epoches = 5    # 设置 epoch 数目\n",
    "cost = []     # 损失函数累加\n",
    "\n",
    "for epoch in range(num_epoches):    \n",
    "\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # 输入样本和标签\n",
    "        inputs, labels = data\n",
    "        inputs = inputs.view(-1, 28, 28)  # 设置 RNN 输入维度为 (batch, time_step, input_size)\n",
    "\n",
    "        # 每次训练梯度清零\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # 正向传播、反向传播和优化过程\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # 打印训练情况\n",
    "        running_loss += loss.item()\n",
    "        if i % 2000 == 1999:    # 每隔2000 mini-batches，打印一次\n",
    "            print('[epoch: %d, mini-batch: %5d] loss: %.3f' % \n",
    "                 (epoch + 1, i + 1, running_loss / 2000))\n",
    "            cost.append(running_loss / 2000)\n",
    "            running_loss = 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEKCAYAAAAW8vJGAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAIABJREFUeJzt3Xt8nGWZ//HPlUkmSXNom0malpaSlrbQUpFDKCDIQYoWdGF1OVhFxVPdXVnXdd1V9+eqP/y56+HnaVcUWUTQVQ4rqBU5SOWkYLFBgbbpkXIKbZO0CW3SnJNr/3iepJM0p2kymUnm+3695jXzPPMcrsyrnWvu+37u6zF3R0REpFdWqgMQEZH0osQgIiL9KDGIiEg/SgwiItKPEoOIiPSjxCAiIv0oMYiISD9KDCIi0o8Sg4iI9JOd6gCORmlpqVdUVKQ6DBGRSeXpp5/e5+5lI203KRNDRUUFVVVVqQ5DRGRSMbOXRrOdupJERKQfJQYREelHiUFERPpRYhARkX6SmhjM7BYzqzOzTcNsc4GZPWNmm83ssWTGIyIiI0t2i+FWYNVQb5rZDOC7wGXufhJwZZLjERGRESQ1Mbj740DDMJu8C7jH3V8Ot69LZjwiIjKyVI8xLAFmmtmjZva0mb13qA3NbI2ZVZlZVX19/VGd7E8vN/KVB7bS2d1ztPGKiEx5qU4M2cDpwFuBtwD/amZLBtvQ3W9y90p3rywrG3Hi3qA27z7I9x59nsZDHUcdsIjIVJfqxFADPODuh9x9H/A48Ppknay0IArAvmYlBhGRoaQ6MfwSeKOZZZvZNOBMYEuyThYrzAVg/6H2ZJ1CRGTSS2qtJDO7HbgAKDWzGuDzQA6Au9/o7lvM7AHgOaAHuNndh7y0daxihUGLoUFdSSIiQ0pqYnD31aPY5mvA15IZR6+YupJEREaU6q6kCVWcl0N2lrG/WV1JIiJDyajEkJVllBRE2a8Wg4jIkDIqMUAwAK3BZxGRoWVcYigtjGqMQURkGBmXGGIFUV2VJCIyjIxLDCUFuRp8FhEZRsYlhlhhlEMd3bR2dKc6FBGRtJRxiaE0nOSmAWgRkcFlXGKIFYRlMTQALSIyqMxLDGoxiIgMK+MSQ2mhWgwiIsPJuMRQUtDbYlBiEBEZTMYlhmnRCHk5WbpkVURkCBmXGMyMWEGuupJERIaQcYkBwrIY6koSERlURiaGWKFmP4uIDCUzE4NKb4uIDCmpicHMbjGzOjMb9nadZnaGmXWb2RXJjKdXrDCXhkMduPtEnE5EZFJJdovhVmDVcBuYWQT4CvBgkmPpEyuI0tHdQ1N710SdUkRk0khqYnD3x4GGETb7O+BuoC6ZscTrm/2s7iQRkSOkdIzBzOYCbwduHMW2a8ysysyq6uvrx3TeWN/sZw1Ai4gMlOrB528Bn3L3EWtgu/tN7l7p7pVlZWVjOmksnP2sO7mJiBwpO8XnrwTuMDOAUuBSM+ty918k86R99ZJUSE9E5AgpTQzuvqD3tZndCtyb7KQAh+slNajFICJyhKQmBjO7HbgAKDWzGuDzQA6Au484rpAs0ewsivKyVUhPRGQQSU0M7r46gW2vTWIoRygtzGWfBp9FRI6Q6sHnlNHsZxGRwWVuYiiMavBZRGQQGZwYVHpbRGQwGZsYSguiNLZ00N2jekkiIvEyNjHECnPpcXitRa0GEZF4GZsYdO9nEZHBZWxi6C2kp0tWRUT6y9jE0FcWQwPQIiL9ZGxi6C2kpwqrIiL9ZWximDEtSpZpjEFEZKCMTQyRLKOkIKrEICIyQMYmBgiuTFJXkohIfxmdGGIFmv0sIjJQZieGQnUliYgMlNGJQaW3RUSOlNGJIVYQpamti/auEW85LSKSMTI7MYST3BoPdaY4EhGR9JHUxGBmt5hZnZltGuL9d5vZc+HjSTN7fTLjGai3XpK6k0REDkt2i+FWYNUw778AnO/uJwNfBG5Kcjz9lBaqkJ6IyEDJvufz42ZWMcz7T8YtrgfmJTOegWJ99ZLUYhAR6ZVOYwwfBO4f6k0zW2NmVWZWVV9fPy4n7K2wqrkMIiKHpUViMLMLCRLDp4baxt1vcvdKd68sKysbl/MW5WYTjWSxT/d+FhHpk9SupNEws5OBm4FL3H3/BJ+bWGGUBrUYRET6pLTFYGbzgXuA97j79lTEoNnPIiL9JbXFYGa3AxcApWZWA3weyAFw9xuBzwEx4LtmBtDl7pXJjGmgkoJcDT6LiMRJ9lVJq0d4/0PAh5IZw0hKC6I8X9ecyhBERNJKWgw+p1LQldSOu6c6FBGRtKDEUJhLW2cPLR2qlyQiAkoMcfd+1gC0iAgoMVDaO/tZcxlERAAlhr5CemoxiIgEMj4x9JXFUItBRARQYiBWEHQl7VOLQUQEUGIgPxqhIBpRV5KISCjjEwMEl6yqK0lEJKDEQDDO0KB6SSIigBIDEIwzaIxBRCSgxEAwyU2F9EREAkoMHO5K6ulRvSQRkVEnBjO70syKwtefNbN7zOy05IU2cWKFuXT1OAfbOlMdiohIyiXSYvhXd28ys3OBtwC3Ad9LTlgTqzSc5KZxBhGRxBJDb/nRtwLfc/dfAtHxD2ni9U5y05VJIiKJJYZXzez7wFXAfWaWm+D+aauvLIYGoEVEEvpivwp4EFjl7q8BJcA/DbeDmd1iZnVmtmmI983M/sPMdprZc6kas+gtvb1PLQYRkYQSwxzg1+6+w8wuAK4E/jjCPrcCq4Z5/xJgcfhYQ4rGLGYWqMUgItIrkcRwN9BtZouAHwALgJ8Ot4O7Pw40DLPJ5cCPPLAemGFmcxKIaVzkRLKYMS1H9ZJEREgsMfS4exfwDuBb7v4PBK2IsZgLvBK3XBOum3CxgqjqJYmIkFhi6DSz1cB7gXvDdTljPL8Nsm7QWWZmtsbMqsysqr6+foynPVKsUGUxREQgscTwfuBs4Evu/oKZLQD+e4znrwGOjVueB+webEN3v8ndK929sqysbIynPVKpCumJiAAJJAZ3rwY+CWw0s+VAjbt/eYznXwu8N7w66SzggLvvGeMxj0qJ6iWJiACQPdoNwyuRbgNeJOgCOtbM3hcOMA+1z+3ABUCpmdUAnyfsfnL3G4H7gEuBnUALQaskJWIFuTS2dNLV3UN2ZEpMzxAROSqjTgzA14E3u/s2ADNbAtwOnD7UDu6+ergDursDH00ghqTpLYvR0NLBrKK8FEcjIpI6ifw0zulNCgDuvp2xDz6njVhhUBZDl6yKSKZLpMVQZWY/AH4cLr8beHr8Q0qNWN8kNyUGEclsiSSGvyHo9vkYwRjD48B3kxFUKvS1GDSXQUQy3KgTg7u3A98IH1NOaaFaDCIiMIrEYGYbGWLSGYC7nzyuEaVIcV4OkSxTi0FEMt5oWgxvS3oUaSAry8K5DGoxiEhmGzExuPtLozmQmf3B3c8ee0ipEyuIqiyGiGS88ZzJNekv/i8tzFVXkohkvPFMDEOOQ0wWMdVLEhGZGrfmHC+xglyNMYhIxhvPxDBYCe1JJVYYpbm9i7bO7lSHIiKSMuOZGN4zjsdKib7Zz+pOEpEMNurEYGZNZnZwwOMVM/u5mS10903JDHQiHK6XpAFoEclciZTE+AbBTXR+StBt9E5gNrANuIWgvPakFtPsZxGRhLqSVrn79929yd0PuvtNwKXuficwM0nxTajSgqDFsE8tBhHJYIkkhh4zu8rMssLHVXHvTfpLVeFwi0GXrIpIJkskMbybYIC5DqgNX19jZvnAdUmIbcJNi0bIy8nS4LOIZLRE7vm8y93/wt1L3b0sfL3T3Vvd/fdD7Wdmq8xsm5ntNLNPD/L+fDN7xMz+bGbPmdmlR/vHjJWZESvIVVeSiGS0RO75XAZ8GKiI38/dPzDMPhHgBuBioAbYYGZr3b06brPPAne5+/fMbBnBfaArEvgbxlWsUIX0RCSzJXJV0i+B3wHrgNHOAFsB7HT3XQBmdgdwORCfGBwoDl9PJ7jyKWViBVHq1WIQkQyWSGKY5u6fSvD4c4FX4pZrgDMHbPMF4Ddm9ndAAbAywXOMq1hhLlv3NqUyBBGRlEpk8Pneo+j/H6xMxsArmFYDt7r7POBS4MdmdkRcZrbGzKrMrKq+vj7BMEYvVhhl/6EO3KfEhVYiIglLJDH8PUFyaA1nPTeZ2cER9qkBjo1bnseRXUUfBO4CcPc/EJTvLh14IHe/yd0r3b2yrKwsgbATUxEroKOrhx11zUk7h4hIOkvkqqQid89y93x3Lw6Xi0fYbQOw2MwWmFmUYLb02gHbvAxcBGBmSwkSQ/KaBCO48IRZAKzbUpuqEEREUmrExGBmJ4bPpw32GG5fd+8imOPwILCF4OqjzWZ2vZldFm72j8CHzexZ4HbgWk9hP87s6XmcPG8666qVGEQkM41m8PkTwBrg64O858CbhtvZ3e8juAQ1ft3n4l5XA+eMIo4Js3JpOd9ct536pnbKinJTHY6IyIQascXg7mvC5wsHeQybFCarlUvLcYdHttalOhQRkQmXyOWqmNkbOHKC24/GOaaUWzqniLkz8vlNdS1XnXHsyDuIiEwhicx8/jFwPPAMhye4OTDlEoOZcfGycu7Y8DKtHd3kRyOpDklEZMIk0mKoBJalcmB4Iq1cWs6tT77IEzv3sXJZearDERGZMInMY9hEcGOejLBiQQlFudm6bFVEMk4iLYZSoNrM/gj0FRNy98uG3mXyimZncf4JZazbUkdPj5OVNdgkbhGRqSeRxPCFZAWRri5eVs69z+3h2ZrXOHX+lLhJnYjIiEadGNz9sWQGko4uWDKLSJbxUHWtEoOIZIxRjzGY2TvMbIeZHUigVtKkNn1aDmcuKNE4g4hklEQGn78KXObu0xOolTTprVxazvbaZl7afyjVoYiITIhEEkOtu29JWiRpauXS4FLVdVs0C1pEMkMiiaHKzO40s9Vht9I7zOwdSYssTcyPTeOE8iIV1RORjJHIVUnFQAvw5rh1DtwzrhGloZXLZnHjY7s40NLJ9Gk5qQ5HRCSpErkq6f3JDCSdrVxazg2PPM8j2+r4y1PnpjocEZGkSqRW0g858racuPsHxjWiNPT6eTMoK8rloS21SgwiMuUl0pV0b9zrPODtHHmbzikpK8tYuXQWv3p2Dx1dPUSzExmaERGZXBK5tefdcY+fAFcBy5MXWnpZubSc5vYunnphf6pDERFJqrH89F0MzB9pIzNbZWbbzGynmX16iG2uMrNqM9tsZj8dQ0xJc86iUvJysnR1kohMeaNKDBboCWc8HwxnPP8K+NQI+0WAG4BLgGXAajNbNmCbxcBngHPc/STg40fxdyRdXk6ENy4u46HqWjKk8riIZKhRJYbwHgzPhDOeex9L3P3uEXZdAex0913u3gHcAVw+YJsPAze4e2N4rrSdSXbx0nJ2H2ijes+UrgQiIhkuka6kJ83sjASPPxd4JW65JlwXbwmwxMyeMLP1ZrYqwXNMmDctnYUZrKtO29wlIjJmiSSGNwHrzex5M3vOzDaa2XMj7DPYTQwG9sNkE4xXXACsBm42sxlHHMhsjZlVmVlVfX19AmGPn9LCXE6bP1NF9URkSkvkctVLjuL4NcCxccvzOPIS1xpgvbt3Ai+Y2TaCRLEhfiN3vwm4CaCysjJlnfwrl5bzlQe2sudAK3Om56cqDBGRpEnkctWXBnuMsNsGYLGZLTCzKPBOYO2AbX4BXAhgZqUEXUu7Rv8nTKyLl80C4LcqqiciU1RSZ2q5exdwHfAgsAW4y903m9n1ZtZ7S9AHgf1mVg08AvyTu6ftZIHjywqpiE3jIV22KiJTVCJdSUfF3e8D7huw7nNxrx34RPhIe2bGm0+azQ+feIHdr7VyzAx1J4nI1KLaDkfhvWcfB8ANj+xMcSQiIuNPieEozJs5javPOJa7ql7hlYaWVIcjIjKulBiO0kcvXISZ8Z2H1WoQkalFieEozZmez7tWzOdnf6rR/aBFZEpRYhiDv73geLKzjG//dkeqQxERGTdKDGMwqziP9559HL/486s8X9+c6nBERMaFEsMYfeT848nNjvDtdWo1iMjUoMQwRqWFubzvDRX86rndbK9tSnU4IiJjpsQwDj5y3kKm5ajVICJTgxLDOJhZEOUD5y7g1xv3UL1b92oQkclNiWGcfOjchRTlZfOtddtTHYqIyJgoMYyT6dNy+NC5C/lNdS0baw6kOhwRkaOmxDCO3n9uBdPzc/imWg0iMokpMYyj4rwc1py3kIe31vHnlxtTHY6IyFFRYhhn73tDBSUFUb7xkFoNIjI5KTGMs8LcbD5y3kJ+t2MfG15sSHU4IiIJU2JIgveeXUFpYS5ffWArXd09qQ5HRCQhSU8MZrbKzLaZ2U4z+/Qw211hZm5mlcmOKdnyoxH+edUJbHixkX/+2XP09HiqQxIRGbWk3trTzCLADcDFQA2wwczWunv1gO2KgI8BTyUznol0VeWx1B5o4+sPbWdaboQvXr4cM0t1WCIiI0p2i2EFsNPdd7l7B3AHcPkg230R+CrQluR4JtR1b1rER85fyH+vf5kvP7CV4PbWIiLpLdmJYS7wStxyTbiuj5mdChzr7vcmOZYJZ2Z8etWJXHPWfL7/2C7d7U1EJoWkdiUBg/Wd9P1sNrMs4JvAtSMeyGwNsAZg/vz54xRe8pkZ11+2nJb27rBbKZsPnrsg1WGJiAwp2S2GGuDYuOV5wO645SJgOfComb0InAWsHWwA2t1vcvdKd68sKytLYsjjLyvL+OoVJ7PqpNl88d5q7tzwcqpDEhEZUrITwwZgsZktMLMo8E5gbe+b7n7A3UvdvcLdK4D1wGXuXpXkuCZcdiSLb68+hfOXlPHpezbyq2d3j7yTiEgKJDUxuHsXcB3wILAFuMvdN5vZ9WZ2WTLPnY5ysyPceM3pnHFcCf9w5zOsq65NdUgiIkewyXilTGVlpVdVTd5GRVNbJ++++Sm27m3ih9eewTmLSlMdkohkADN72t1HnCummc8pUJSXw23vX8GCWAEfuHUDv92iloOIpA8lhhSZWRDlJx8+kyXlRaz58dP87OmaVIckIgIoMaRUaWEut685i7MWlvDJ/3mWmx5/PtUhiYgoMaRaYW42t1x7Bm993Rz+7b6t/Pt9WzRDWkRSKtkT3GQUcrMj/MfqUykpiPL9x3ex/1AHX37H68iOKG+LyMRTYkgTkSzj+stPIlYY5VvrdtB4qIPvvOs08qORVIcmIhlGP0nTiJnx8ZVL+OJfLufhbXW85wdPcaClM9VhiUiGUWJIQ+856zi+s/o0nqs5wFXf/wMv72+hW/d0EJEJoq6kNPXWk+cwY1oOa35UxXlfewSAgmiE4vwcivKyKcrLoTh8LsrL5qyFMf7i9cekOGoRmQo08znN7axr5rHt9TS1ddLU1sXB1uC5qf3wcmNLJwdaO1m94li+cNlJ5GZrXEJEjjTamc9qMaS5RbMKWTSrcNhtunucr/9mG9999Hm27GnixmtOZ/b0vAmKUESmGo0xTAGRLOOfV53I9959Gjtqm3jbf/6eP77QkOqwRGSSUmKYQi553Rx+8dFzKMrL5l3/tZ7bnnxRk+VEJGFKDFPM4vIifnndOVxwQhmfX7uZf/yfZ2nr7E51WCIyiSgxTEHFeTnc9J5KPr5yMff86VWuuPFJahpbUh2WiEwSSgxTVFZWMFnu5vdW8tK+Fi77zhOsfXa3Wg8iMiJdrpoBdtU389f//TTba5uZFo1w4QmzWLV8Nm86cRYFubowTSRTpM3lqma2Cvg2EAFudvcvD3j/E8CHgC6gHviAu7+U7LgyycKyQn79sTfy1K4G7t+0hwc37+XXG/eQm53FeUvKuPR1s7loaTnFeTmpDlVE0kBSWwxmFgG2AxcDNcAGYLW7V8dtcyHwlLu3mNnfABe4+9XDHVcthrHp7nGqXmzg/k17eWDTXvYebCMnYpyzqJSLlpZz9sISji8rxMxSHaqIjKN0aTGsAHa6+64wqDuAy4G+xODuj8Rtvx64JskxZbxIlnHmwhhnLozxubct45ma17h/4x7u37SXR7fVA1BaGOXMBTHOXFjCWQtjLJ6lRCGSKZKdGOYCr8Qt1wBnDrP9B4H7kxqR9JOVZZw2fyanzZ/Jv1y6lJcbWli/az9P7Wpg/a79/HrjHgBKCqKcuaCEMxeUcFysgJxIFjkRIyc7i2gki2h2Vt+6aCSLWGEukSwlEpHJKNmJYbBvhkH7rszsGqASOH+I99cAawDmz58/XvFJHDPjuFgBx8UKuPqM+bg7NY2t/CEuUdy/ae+ojlVamMtbTirnkuVzOHNhCTm66ZDIpJHsMYazgS+4+1vC5c8AuPu/D9huJfCfwPnuXjfScTXGkDo1jS3sb+6go7uHzq6e4Lnb6ejqobM7WG7r7OapXQ08vLWO1s5uZkzL4eKl5Vzyutmcs6hURf5EUmS0YwzJTgzZBIPPFwGvEgw+v8vdN8dtcyrwM2CVu+8YzXGVGCaHts5uHttezwOb9rKuupam9i6KcrO5aOksVi2fw5kLSphZEE11mCIZIy0Gn929y8yuAx4kuFz1FnffbGbXA1Xuvhb4GlAI/E84uPmyu1+WzLhkYuTlRHjLSbN5y0mzae/q5smd+7l/0x4eqq7lF8/sBoIupxNmF7J4VhFLyouC1+VFunRWJIU0wU0mXFd3DxtebGTTqwfYXtvE9tomdtQ109JxeFb27OI8FpcXUpibTY87PQ4+4Lkn/Lf7urnTuXhZOa+fN4MsDXiLDCktupKSRYlh6unpcV59rTVMFM3sCJNFe1c3WWaYGVkGWeEz4XNXt1O95yDdPU5ZUS4rl87i4mXlvOH4UvJyNJYhEk+JQTLGgZZOHtlWx0PVtTy2vZ7m9i7ycyKct6SUlUvLuWhpOSXjMJbR3N7FtJyIWiUyaSkxSEZq7+pm/a4G1lXXsm5LLXsOtJFlsOyYYlZUxFixoIQzKmYSK8wd8Vj1Te2s37WfP+zaz/rn97Nr3yGikSzmzsxnXt9jWt/zsTPzKS3MVeKQtKXEIBnP3dm8+yC/3VLH+l37+dPLjbR39QCweFYhKxaU9D3mTM+n4VBHkAieD5LBzrpmAApzs1mxoITT5s+gqb2LmsZWahpbebWxhX3NHf3OmZudxfFlhZww+/Bg+gmzizlmet6wM8cPtXfx6mutvNrYSs1rrTS3dXHuolKWzy3WjHMZN0oMIgN0dPWw8dXXeOqFBv74QgNVLzbS3N4FBFdH7WtuB2BaNEJlRQlnL4xx9vExlh9TTPYQE/RaOrrY/VorrzS0UtPYwkv7W9hR18z22ib2HGjr264wN5sl5UHCmDdzGvua23m1sTVIBq+18lpL56DHnzsjn4uXlbNq+WzOqCjRbHIZEyUGkRF0dfewdW8TT73QwKZXD3B8WQFnHx/j5HkzxmWm9oHWTnbUNrGttonte4PnbXubaGzppCAaYe7MfI6Zkc/cGfnMnRk8z5uZz9wZ08iJGI9sC+aA/G5HPe1dPZQURFm5NCiZPtLgurvT1hm0jvKjGoSXgBKDSBpyd1o7u8nPiYy6i+hQexePba/nwc17eXhLHU3tXRREI5yxoIQeh9aOLg61d9PS0cWhjm5aO7o51NFF73/tY6bnsai8iCWzCllSXsSi8kIWzyqkSHNFMo4Sg8gU1NHVw5PP7+PBzbX8+eVGcnMiFEQjTItGmBbNpiA3Qn5O+ByN0NPj7KxrZkddMzvrmvvGWOBwwpg7I5/O7h7au4JyJm2d3bR39dAePrd1dtPZHXxP9H5f9H5rxH99RLOzKC/OZVZxHrOL8ygvzqW8OI9ZRYdfj+XGUE1tnbyw7xAvhBcBnH7cTGYV5x318TKREoOI9NPd47zScHgMpHeuyN4DbeRmZ5GbE+l7zsvOIi9czsuJkBPJoreB09vOObwcvGjt7KauqY26g+3sPdjWb8Jir8LcbGYV5zKrqDdpBM9lccs97uyqP9SXBHaFz/VN7Uccb37JNCorZlJ5XHC12fFlhYNeFdbR1cPOuma27DnIlj0Hqd5zkK17m8iJGEvnFPc9ls0poiJWMOSY0mSnxCAiKePuNLd3UXuwnbqDbdQ2tVF7sJ29B9qob2qnLlyuPdjWrxUzUKwgysKyAhaUFrCgtJAFpQUsLCugpaObqheDCwiqXmrouzpsen4OlcfNpLKihJyIUb3nIFv2NLGzrqmv1ZObncWJs4s4cXYxnd09VO85yPP1zf3eP2F2EUtnF7N0ThGxwlzMggQYPPcmxfjlYMKlWe/rIF0GkzOD9bnZvS27oDU3LZpNfk5kQi8oUGIQkbTn7hxs66LuYBt1TUGigOB2tAtiBUyfNvI4iLvz0v4WNoSJYsNLDeyqPwRAeXHuqFoEA1sUW/YGCaXhUMdgpxxXudlZfV2BkSzr1zKzMMEQt+5vL1jEX50+76jOlRZF9EREhmNmTM/PYXp+DovLi476GBWlBVSUFnBl5bEANBzqwN1HNZERgvGRZccUs+yY4r517k59UzsHWjtxgvEUx4PnuNfEvTewppcTlHvp8WDyZWtHNy0d3bR0dtPa0UVLx+F1hzq66Ok5PIYTHHPAuI4zLrP4R6LEICJTznh8eZoZs4rzMnKAe2qOsIiIyFFTYhARkX6UGEREpB8lBhER6SfpicHMVpnZNjPbaWafHuT9XDO7M3z/KTOrSHZMIiIytKQmBjOLADcAlwDLgNVmtmzAZh8EGt19EfBN4CvJjElERIaX7BbDCmCnu+9y9w7gDuDyAdtcDtwWvv4ZcJGpAL2ISMokOzHMBV6JW64J1w26jbt3AQeAWJLjEhGRISR7gttgv/wH1uAYzTaY2RpgTbjYbGbbjjKmUmDfUe6bKop5Yky2mCdbvKCYJ8pQMR83mp2TnRhqgGPjlucBu4fYpsbMsoHpQMPAA7n7TcBNYw3IzKpGUysknSjmiTHZYp5s8YJinihjjTnZXUkbgMVmtsDMosA7gbUDtlkLvC98fQXwsE/Gyn4iIlNEUlsM7t5lZtcBDwIR4BZ332xm1wNV7r4W+AHwYzPbSdBSeGcyYxIRkeElvYieu98H3Ddg3efiXrcBVyY7jjhj7o5KAcU8MSZbzJMtXlDME2VMMU/K+zEDDagcAAAIcklEQVSIiEjyqCSGiIj0k1GJYaTyHOnIzF40s41m9oyZpeVt68zsFjOrM7NNcetKzOwhM9sRPs9MZYzxhoj3C2b2avg5P2Nml6YyxoHM7Fgze8TMtpjZZjP7+3B9Wn7Ow8Sbtp+zmeWZ2R/N7Nkw5v8brl8QluvZEZbvSf6dckZpmJhvNbMX4j7nUxI6bqZ0JYXlObYDFxNcIrsBWO3u1SkNbARm9iJQ6e5pex21mZ0HNAM/cvfl4bqvAg3u/uUwCc9090+lMs5eQ8T7BaDZ3f9/KmMbipnNAea4+5/MrAh4GvhL4FrS8HMeJt6rSNPPOay4UODuzWaWA/we+HvgE8A97n6Hmd0IPOvu30tlrL2GifmvgXvd/WdHc9xMajGMpjyHHAV3f5wj557Elzq5jeBLIS0MEW9ac/c97v6n8HUTsIWgakBafs7DxJu2PNAcLuaEDwfeRFCuB9LoM4ZhYx6TTEoMoynPkY4c+I2ZPR3O/p4syt19DwRfEsCsFMczGteZ2XNhV1NadMkMJqxAfCrwFJPgcx4QL6Tx52xmETN7BqgDHgKeB14Ly/VAGn5vDIzZ3Xs/5y+Fn/M3zWx0N78OZVJiGFXpjTR0jrufRlCh9qNhN4iMv+8BxwOnAHuAr6c2nMGZWSFwN/Bxdz+Y6nhGMki8af05u3u3u59CUKVhBbB0sM0mNqrhDYzZzJYDnwFOBM4ASoCEuhczKTGMpjxH2nH33eFzHfBzgn+sk0Ft2M/c299cl+J4huXuteF/sB7gv0jDzznsQ74b+Im73xOuTtvPebB4J8PnDODurwGPAmcBM8JyPZDG3xtxMa8Ku/Lc3duBH5Lg55xJiWE05TnSipkVhAN3mFkB8GZg0/B7pY34UifvA36ZwlhG1PvlGno7afY5h4OMPwC2uPs34t5Ky895qHjT+XM2szIzmxG+zgdWEoyNPEJQrgfS6DOGIWPeGvdjwQjGRBL6nDPmqiSA8NK4b3G4PMeXUhzSsMxsIUErAYJZ6j9Nx5jN7HbgAoKKjrXA54FfAHcB84GXgSvdPS0GfIeI9wKC7g0HXgQ+0tt3nw7M7Fzgd8BGoCdc/S8E/fZp9zkPE+9q0vRzNrOTCQaXIwQ/mu9y9+vD/4d3EHTJ/Bm4JvwlnnLDxPwwUEbQhf4M8Ndxg9QjHzeTEoOIiIwsk7qSRERkFJQYRESkHyUGERHpR4lBRET6UWIQEZF+lBgkrZjZZTZC5VszO8bMBi0OZmaPmtmo73VrZqeMpsKnmY36Ur8RjjPHzO4dj2MNcuyvmdnWsAzCz3uvbw/f+4wFVYW3mdlb4tYPWnF4qIqiZnadmb0/GfFL+lBikLTi7mvd/csjbLPb3a8YbpsEnAJMZOnnTxDM+B2zsGJwvIeA5e5+MkEl4c+E2y0jmNB5ErAK+G5YXycC3EBQbmUZsDrcFuArwDfdfTHQCHwwXH8L8LHxiF/SlxKDTAgzqwh/zd5sZpvM7CdmttLMngh/la4It7vWzL4Tvr7VzP7DzJ40s11mdkXcsYabyXlNuM+muOOuCNf9OXw+IfwVfD1wtQU16682s0Iz+6EF98B4zsz+Ku5v+JIFde/Xm1l5uK7MzO42sw3h45xw/fl2uBb+n3tnsAN/BTwQ97f+0sweCH+1fz7uXNdYUGf/GTP7fm8SMLNmM7vezJ4Czo7/o939N3HF3tYTlG+AoALrHe7e7u4vADsJSiQMWnE4nC07aEVRd28BXuz9XGVqUmKQibQI+DZwMkGBr3cB5wKfJJgVO5g54TZvA4ZtScQpcPc3AH9L8AsXYCtwnrufCnwO+Lfwy/BzwJ3ufoq73wn8K3DA3V8X/vJ+uPeYwHp3fz3wOPDhcP23CX5Zn0HwpX9zuP6TwEfD4mZvBFrNbAHQOGDW7Arg3QQtlyvNrNLMlgJXExRQPAXoDrfpjWOTu5/p7r8f5jP4AHB/+HqoysJDrY8xfEXRqvBvkikqe+RNRMbNC+6+EcDMNgO/dXc3s41AxRD7/CIsuFbd+yt9FG6H4L4LZlYc9rUXAbeZ2WKCcgw5Q+y7kqDbhfAYjeHLDqB3bOBpghs+9W6/LPiRDUBx2Dp4AviGmf2E4CYvNWH9mvoB53vI3fcDmNk9BEmwCzgd2BAeN5/DxfG6CQrTDcnM/k94jJ/0rhpkM2fwH4Y+zPa96ggSu0xRSgwykeJ/KffELfcw9L/F+H2O+MIysx8S1Prf7e69YwUD67w48EXgEXd/uwX3B3h0iPPZIPsDdPrh+jHdcfFmAWe7e+uA7b9sZr8mGL9Yb2YrgVYgb5DYBi4bcJu7f2aQONrcvXuI2DGz9xG0ri6Ki3e4ysKDrd9HWFE0bDUMrCiaF/4tMkWpK0kmNXd/f9gNFD+AfDX0FXI74O4HgOnAq+H718Zt20TQmuj1G+C63gUb+UYyA7c/JXw+3t03uvtXCLpeTiQYEK4YsP/FFty3OZ+gH/8J4LfAFWY2KzxWiZkdN0IcmNkqgrr7l4VjAb3WAu80s9ywO2sx8EeGqDgcJpThKoouIY2qosr4U2KQqajRzJ4EbuTw1TRfBf7dzJ4gqETZ6xGCrqBnzOxq4P8BM8OB62eBC0c418eAynCguprgXrsAH487Ritwv7sfAp43s0Vx+/8e+DFBBcy73b3Kg/uQf5bgzn3PEVxtFF+ueijfIUhyD4V/z40A7r6ZoAJrNcHA90fDeyJ0ESS1BwnKS98VbgtBgvmEme0kGHP4Qdx5zgHWjSIemaRUXVVkApnZ24HT3f2zZnYtUOnu142wW9ows1OBT7j7e1IdiySPxhhEJpC7/9zMYqmOYwxKCa7ckilMLQYREelHYwwiItKPEoOIiPSjxCAiIv0oMYiISD9KDCIi0o8Sg4iI9PO/YKQ6WoW1WIMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(cost)\n",
    "plt.xlabel('mini-batches(per 2000)')\n",
    "plt.ylabel('running_loss')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 60000 test images: 96.355 %\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in trainloader:\n",
    "        images, labels = data\n",
    "        images = images.view(-1, 28, 28)\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Accuracy of the network on the 60000 test images: %.3f %%' % \n",
    "     (100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 10000 test images: 95.790 %\n"
     ]
    }
   ],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        images = images.view(-1, 28, 28)\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Accuracy of the network on the 10000 test images: %.3f %%' % \n",
    "     (100 * correct / total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
